# ===============================================
# RX 480 / GPU Ceiling Finder — Fully Optimized
# ===============================================
import pyopencl as cl
import numpy as np
import time

# ---------------------------
# Auto-detect GPU device
# ---------------------------
def _first_device_of_type(kind):
    """Return the first OpenCL device whose type bitfield includes *kind*.

    Platforms and their devices are scanned in the order the OpenCL runtime
    reports them. Returns ``None`` when no device matches.
    """
    for plat in cl.get_platforms():
        for candidate in plat.get_devices():
            if candidate.type & kind:
                return candidate
    return None


# Prefer a GPU; only when none is present fall back to a CPU device.
device = _first_device_of_type(cl.device_type.GPU)
if device is None:
    print("[WARN] No GPU found. Falling back to CPU.")
    device = _first_device_of_type(cl.device_type.CPU)
    if device is None:
        raise RuntimeError("No OpenCL devices found.")

# Single-device context and its command queue, used by everything below.
ctx = cl.Context([device])
queue = cl.CommandQueue(ctx)

# Report the basic properties of the device that was selected.
print("Using device:", device.name)
print("Global Memory (MB):", device.global_mem_size // 1024**2)
print("Compute Units:", device.max_compute_units)
print("Max Clock (MHz):", device.max_clock_frequency)

# ---------------------------
# Kernel (prismatic recursion with min depth pruning)
# ---------------------------
# OpenCL C source for the benchmark kernel. Each work-item:
#   1. loads one float from `data` at its global id,
#   2. applies x = sqrt(x * 1.618f + 0.5f) * 1.0001f once for every i in
#      [depth_min, expansion_factor) -- i.e. zero times when
#      expansion_factor <= depth_min,
#   3. stores the result back to the same slot (in place).
# Notes for host code:
#   - The kernel performs no bounds check, so the global work size must not
#     exceed the number of floats in the buffer bound to `data`.
#   - `expansion_factor` and `depth_min` are declared as OpenCL `int`, so the
#     host must pass values that fit in a signed 32-bit integer.
kernel_code = """
__kernel void recurse_prismatic(
    __global float *data,
    const int expansion_factor,
    const int depth_min)
{
    int gid = get_global_id(0);
    float x = data[gid];

    // Skip trivial recursion below threshold
    for(int i = depth_min; i < expansion_factor; i++){
        x = sqrt(x * 1.618f + 0.5f) * 1.0001f;
    }
    data[gid] = x;
}
"""
# Compile the kernel source for the context created above.
program = cl.Program(ctx, kernel_code).build()
# Create the kernel object once; later code re-sets its arguments and
# enqueues this same instance instead of looking the kernel up again.
kernel = cl.Kernel(program, "recurse_prismatic")  # reuse kernel instance

# ---------------------------
# Recursive expansion helper
# ---------------------------
def expansion(depth):
    """Return the expansion factor for *depth*, i.e. 8 raised to that power."""
    branching = 8  # growth per depth level
    return pow(branching, depth)

def compute_safe_N(exp_factor, device):
    """Return a float32 element count sized from the device's memory budget.

    The budget is 80% of ``device.global_mem_size``. It is divided by 4
    (bytes per float32) and then by ``exp_factor``, so larger expansion
    factors yield proportionally smaller buffers.

    The result is additionally capped so that a buffer of ``N`` float32
    values does not exceed ``device.max_mem_alloc_size`` when the device
    exposes that attribute; a single OpenCL buffer larger than that limit
    cannot be allocated even if total global memory would allow it.

    Args:
        exp_factor: Positive expansion factor used to scale the budget down.
        device: Object exposing ``global_mem_size`` (bytes) and, optionally,
            ``max_mem_alloc_size`` (bytes).

    Returns:
        The element count, never less than 1.

    Raises:
        ValueError: If ``exp_factor`` is zero or negative.
    """
    if exp_factor <= 0:
        raise ValueError(f"exp_factor must be positive, got {exp_factor!r}")

    bytes_per_float = 4
    max_bytes = int(device.global_mem_size * 0.8)
    N = max_bytes // bytes_per_float // exp_factor

    # Respect the per-allocation limit, which is frequently well below the
    # total amount of global memory.
    max_alloc = getattr(device, "max_mem_alloc_size", None)
    if max_alloc:
        N = min(N, int(max_alloc) // bytes_per_float)

    return max(N, 1)

# ---------------------------
# Ceiling finder loop
# ---------------------------
depth_start = 7   # skip low-impact depths
depth_end   = 20  # maximum theoretical depth

# The kernel takes `expansion_factor` as a 32-bit signed `int`; anything
# larger cannot be passed without overflowing (silently wrapping on older
# NumPy, raising OverflowError on newer NumPy).
INT32_MAX = int(np.iinfo(np.int32).max)
# Number of kernel launches averaged for each timing measurement.
TIMED_RUNS = 5
# Floating-point operations per kernel loop iteration, counted from
# `sqrt(x * 1.618f + 0.5f) * 1.0001f`: multiply, add, sqrt, multiply.
FLOPS_PER_ITERATION = 4

for depth in range(depth_start, depth_end + 1):
    exp_factor = expansion(depth)
    if exp_factor > INT32_MAX:
        # Every deeper level is larger still, so there is nothing left to run.
        print(
            f"[STOP] Depth {depth:2d}: expansion factor {exp_factor:,} does not "
            f"fit the kernel's 32-bit int argument; stopping."
        )
        break

    N = compute_safe_N(exp_factor, device)

    data = np.random.rand(N).astype(np.float32)
    buf = cl.Buffer(ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=data)
    try:
        # set kernel args
        kernel.set_arg(0, buf)
        kernel.set_arg(1, np.int32(exp_factor))
        kernel.set_arg(2, np.int32(depth_start))  # depth_min pruning

        # Warmup: keeps one-time launch overhead out of the measurement.
        cl.enqueue_nd_range_kernel(queue, kernel, (N,), None)
        queue.finish()

        # Timed run: enqueue all launches, then block until the queue drains.
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        for _ in range(TIMED_RUNS):
            cl.enqueue_nd_range_kernel(queue, kernel, (N,), None)
        queue.finish()
        dt = (time.perf_counter() - t0) / TIMED_RUNS
    finally:
        # Free device memory now instead of waiting for garbage collection.
        buf.release()

    # The kernel loop runs for i in [depth_start, exp_factor).
    iterations = max(exp_factor - depth_start, 0)

    fps = 1.0 / dt if dt > 0 else float("inf")
    vram_mb = data.nbytes / 1024**2
    flops = (N * iterations * FLOPS_PER_ITERATION) / dt / 1e9 if dt > 0 else float("inf")

    print(f"Depth {depth:2d} | N={N:,} | VRAM={vram_mb:.1f} MB | {fps:.2f} FPS | {flops:.2f} GFLOPs")
